In [1]:
from __future__ import print_function
#%matplotlib inline
import argparse
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML

# Set random seed for reproducibility
manualSeed = 999
#manualSeed = random.randint(1, 10000) # use if you want new results
print("Random Seed: ", manualSeed)
random.seed(manualSeed)
torch.manual_seed(manualSeed)
Random Seed:  999
Out[1]:
<torch._C.Generator at 0x7f271c1646d0>
In [2]:
# Root directory for dataset
dataroot = "resized_data"

# Number of workers for dataloader
workers = 2

# Batch size during training
batch_size = 64

# Spatial size of training images. All images will be resized to this
#   size using a transformer.
image_size = 64

# Number of channels in the training images. For color images this is 3
nc = 3

# Size of z latent vector (i.e. size of generator input)
nz = 100

# Size of feature maps in generator
ngf = 64

# Size of feature maps in discriminator
ndf = 64

# Number of training epochs
num_epochs = 100

# Learning rate for optimizers
lr = 0.002

# Beta1 hyperparam for Adam optimizers
beta1 = 0.5

# Number of GPUs available. Use 0 for CPU mode.
ngpu = 1
In [3]:
# We can use an image folder dataset the way we have it setup.
# Create the dataset
dataset = dset.ImageFolder(root=dataroot,
                           transform=transforms.Compose([
                               transforms.Resize(image_size),
                               transforms.CenterCrop(image_size),
                               transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                           ]))
# Create the dataloader
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                         shuffle=True, num_workers=workers)

# Decide which device we want to run on
device = torch.device("cuda:0" if (torch.cuda.is_available() and ngpu > 0) else "cpu")

# Plot some training images
real_batch = next(iter(dataloader))
#real_batch是一个列表
#第一个元素real_batch[0]是[128,3,64,64]的tensor,就是标准的一个batch的4D结构:128张图,3个通道,64长,64宽
#第二个元素real_batch[1]是第一个元素的标签,有128个label值全为0
plt.figure(figsize=(8,8))
plt.axis("off")
plt.title("Training Images")
plt.imshow(np.transpose(vutils.make_grid(real_batch[0].to(device)[:64], padding=2, normalize=True).cpu(),(1,2,0)))

#这个函数能让图片显示
#plt.show() 
Out[3]:
<matplotlib.image.AxesImage at 0x7f262fded580>
In [4]:
# custom weights initialization called on netG and netD
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)
In [5]:
class Generator(nn.Module):
    def __init__(self, ngpu):
        super(Generator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is Z, going into a convolution
            nn.ConvTranspose2d( nz, ngf * 8, 4, 1, 0, bias=False),
            #输入100,输出64*8,核函数是4*4
            
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            
            # state size. (ngf*8) x 4 x 4
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            
            # state size. (ngf*4) x 8 x 8
            nn.ConvTranspose2d( ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            
            # state size. (ngf*2) x 16 x 16
            nn.ConvTranspose2d( ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            
            # state size. (ngf) x 32 x 32
            nn.ConvTranspose2d( ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh()
            
            # state size. (nc) x 64 x 64
        )

    def forward(self, input):
        return self.main(input)
In [6]:
# Create the generator
netG = Generator(ngpu).to(device)

# Handle multi-gpu if desired
if (device.type == 'cuda') and (ngpu > 1):
    netG = nn.DataParallel(netG, list(range(ngpu)))

# Apply the weights_init function to randomly initialize all weights
#  to mean=0, stdev=0.2.
netG.apply(weights_init)

# Print the model
print(netG)
Generator(
  (main): Sequential(
    (0): ConvTranspose2d(100, 512, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (7): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU(inplace=True)
    (9): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (10): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (11): ReLU(inplace=True)
    (12): ConvTranspose2d(64, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (13): Tanh()
  )
)
In [7]:
class Discriminator(nn.Module):
    def __init__(self, ngpu):
        super(Discriminator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is (nc) x 64 x 64
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf) x 32 x 32
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*2) x 16 x 16
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*4) x 8 x 8
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*8) x 4 x 4
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, input):
        return self.main(input)
In [8]:
# Create the Discriminator
netD = Discriminator(ngpu).to(device)

# Handle multi-gpu if desired
if (device.type == 'cuda') and (ngpu > 1):
    netD = nn.DataParallel(netD, list(range(ngpu)))

# Apply the weights_init function to randomly initialize all weights
#  to mean=0, stdev=0.2.
netD.apply(weights_init)

# Print the model
print(netD)
Discriminator(
  (main): Sequential(
    (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (4): LeakyReLU(negative_slope=0.2, inplace=True)
    (5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (6): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): LeakyReLU(negative_slope=0.2, inplace=True)
    (8): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (9): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): LeakyReLU(negative_slope=0.2, inplace=True)
    (11): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (12): Sigmoid()
  )
)
In [9]:
# Initialize BCELoss function
criterion = nn.BCELoss()
4
# Create batch of latent vectors that we will use to visualize
#  the progression of the generator
fixed_noise = torch.randn(64, nz, 1, 1, device=device)

# Establish convention for real and fake labels during training
real_label = 1
fake_label = 0

# Setup Adam optimizers for both G and D
optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))
In [10]:
# Training Loop

# Lists to keep track of progress
img_list = []
G_losses = []
D_losses = []
iters = 0

print("Starting Training Loop...")
# For each epoch
for epoch in range(num_epochs):
    # For each batch in the dataloader
    for i, data in enumerate(dataloader, 0):

        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        ## Train with all-real batch
        netD.zero_grad()
        # Format batch
        real_cpu = data[0].to(device)
        b_size = real_cpu.size(0)
        label = torch.full((b_size,), real_label, device=device).to(torch.float32)
        # Forward pass real batch through D
        output = netD(real_cpu).view(-1)
        # Calculate loss on all-real batch
        errD_real = criterion(output, label)
        # Calculate gradients for D in backward pass
        errD_real.backward()
        D_x = output.mean().item()

        ## Train with all-fake batch
        # Generate batch of latent vectors
        noise = torch.randn(b_size, nz, 1, 1, device=device)
        # Generate fake image batch with G
        fake = netG(noise)
        label.fill_(fake_label)
        # Classify all fake batch with D
        output = netD(fake.detach()).view(-1)
        # Calculate D's loss on the all-fake batch
        errD_fake = criterion(output, label)
        # Calculate the gradients for this batch
        errD_fake.backward()
        D_G_z1 = output.mean().item()
        # Add the gradients from the all-real and all-fake batches
        errD = errD_real + errD_fake
        # Update D
        optimizerD.step()

        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ###########################
        netG.zero_grad()
        label.fill_(real_label)  # fake labels are real for generator cost
        # Since we just updated D, perform another forward pass of all-fake batch through D
        output = netD(fake).view(-1)
        # Calculate G's loss based on this output
        errG = criterion(output, label)
        # Calculate gradients for G
        errG.backward()
        D_G_z2 = output.mean().item()
        # Update G
        optimizerG.step()

        # Output training stats
        if i % 50 == 0:
            print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                  % (epoch, num_epochs, i, len(dataloader),
                     errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))

        # Save Losses for plotting later
        G_losses.append(errG.item())
        D_losses.append(errD.item())

        # Check how the generator is doing by saving G's output on fixed_noise
        if (iters % 500 == 0) or ((epoch == num_epochs-1) and (i == len(dataloader)-1)):
            with torch.no_grad():
                fake = netG(fixed_noise).detach().cpu()
            img_list.append(vutils.make_grid(fake, padding=2, normalize=True))

        iters += 1
Starting Training Loop...
[0/100][0/93]	Loss_D: 1.6999	Loss_G: 40.1906	D(x): 0.5276	D(G(z)): 0.5628 / 0.0000
[0/100][50/93]	Loss_D: 1.4494	Loss_G: 1.2917	D(x): 0.5064	D(G(z)): 0.4271 / 0.3520
[1/100][0/93]	Loss_D: 1.7604	Loss_G: 1.6836	D(x): 0.3206	D(G(z)): 0.2162 / 0.3798
[1/100][50/93]	Loss_D: 1.3249	Loss_G: 1.5570	D(x): 0.5753	D(G(z)): 0.4049 / 0.3874
[2/100][0/93]	Loss_D: 2.0170	Loss_G: 1.7017	D(x): 0.7297	D(G(z)): 0.5953 / 0.3089
[2/100][50/93]	Loss_D: 1.6342	Loss_G: 4.4252	D(x): 0.8885	D(G(z)): 0.6054 / 0.0254
[3/100][0/93]	Loss_D: 1.7646	Loss_G: 2.4523	D(x): 0.6408	D(G(z)): 0.5972 / 0.1851
[3/100][50/93]	Loss_D: 1.6389	Loss_G: 1.1261	D(x): 0.5377	D(G(z)): 0.5174 / 0.4063
[4/100][0/93]	Loss_D: 1.8466	Loss_G: 1.4022	D(x): 0.4411	D(G(z)): 0.4102 / 0.3977
[4/100][50/93]	Loss_D: 1.2757	Loss_G: 1.5958	D(x): 0.5665	D(G(z)): 0.4087 / 0.2647
[5/100][0/93]	Loss_D: 1.5907	Loss_G: 2.3479	D(x): 0.6143	D(G(z)): 0.5716 / 0.1478
[5/100][50/93]	Loss_D: 1.4556	Loss_G: 1.7111	D(x): 0.5621	D(G(z)): 0.5104 / 0.2095
[6/100][0/93]	Loss_D: 1.2832	Loss_G: 2.0977	D(x): 0.6070	D(G(z)): 0.5048 / 0.1411
[6/100][50/93]	Loss_D: 1.6797	Loss_G: 2.2991	D(x): 0.6827	D(G(z)): 0.6372 / 0.1312
[7/100][0/93]	Loss_D: 2.1469	Loss_G: 3.1729	D(x): 0.8637	D(G(z)): 0.8295 / 0.0800
[7/100][50/93]	Loss_D: 1.4546	Loss_G: 2.0007	D(x): 0.6119	D(G(z)): 0.4956 / 0.2341
[8/100][0/93]	Loss_D: 1.2431	Loss_G: 2.3678	D(x): 0.5243	D(G(z)): 0.2565 / 0.1988
[8/100][50/93]	Loss_D: 1.0137	Loss_G: 1.9053	D(x): 0.6517	D(G(z)): 0.3890 / 0.2040
[9/100][0/93]	Loss_D: 1.3229	Loss_G: 2.1623	D(x): 0.6117	D(G(z)): 0.3121 / 0.2136
[9/100][50/93]	Loss_D: 1.0480	Loss_G: 1.1604	D(x): 0.5511	D(G(z)): 0.2769 / 0.3729
[10/100][0/93]	Loss_D: 1.4377	Loss_G: 2.7750	D(x): 0.6887	D(G(z)): 0.5091 / 0.1207
[10/100][50/93]	Loss_D: 0.6916	Loss_G: 2.7626	D(x): 0.7994	D(G(z)): 0.3329 / 0.0718
[11/100][0/93]	Loss_D: 1.4013	Loss_G: 1.6210	D(x): 0.5177	D(G(z)): 0.3185 / 0.2509
[11/100][50/93]	Loss_D: 1.0651	Loss_G: 1.9566	D(x): 0.7137	D(G(z)): 0.4583 / 0.1701
[12/100][0/93]	Loss_D: 0.9124	Loss_G: 3.3337	D(x): 0.7890	D(G(z)): 0.4327 / 0.0670
[12/100][50/93]	Loss_D: 1.4064	Loss_G: 1.3487	D(x): 0.6426	D(G(z)): 0.5376 / 0.3072
[13/100][0/93]	Loss_D: 1.1960	Loss_G: 1.8684	D(x): 0.4900	D(G(z)): 0.2665 / 0.2087
[13/100][50/93]	Loss_D: 1.0647	Loss_G: 2.0331	D(x): 0.5749	D(G(z)): 0.3427 / 0.1631
[14/100][0/93]	Loss_D: 2.6204	Loss_G: 1.0550	D(x): 0.1890	D(G(z)): 0.0435 / 0.4636
[14/100][50/93]	Loss_D: 0.6243	Loss_G: 2.5740	D(x): 0.7683	D(G(z)): 0.2733 / 0.0913
[15/100][0/93]	Loss_D: 1.8772	Loss_G: 3.7269	D(x): 0.9198	D(G(z)): 0.7731 / 0.0543
[15/100][50/93]	Loss_D: 0.8537	Loss_G: 3.8992	D(x): 0.8302	D(G(z)): 0.4356 / 0.0287
[16/100][0/93]	Loss_D: 1.7657	Loss_G: 2.0543	D(x): 0.5410	D(G(z)): 0.4627 / 0.1847
[16/100][50/93]	Loss_D: 0.8047	Loss_G: 2.6118	D(x): 0.6718	D(G(z)): 0.2821 / 0.1000
[17/100][0/93]	Loss_D: 2.4818	Loss_G: 2.8336	D(x): 0.8645	D(G(z)): 0.7823 / 0.1359
[17/100][50/93]	Loss_D: 1.0354	Loss_G: 2.2132	D(x): 0.5708	D(G(z)): 0.2789 / 0.1293
[18/100][0/93]	Loss_D: 1.7317	Loss_G: 3.6428	D(x): 0.7533	D(G(z)): 0.6480 / 0.0399
[18/100][50/93]	Loss_D: 0.8516	Loss_G: 2.2946	D(x): 0.7705	D(G(z)): 0.4060 / 0.1227
[19/100][0/93]	Loss_D: 1.1431	Loss_G: 1.9733	D(x): 0.6242	D(G(z)): 0.3522 / 0.2112
[19/100][50/93]	Loss_D: 1.1567	Loss_G: 6.1132	D(x): 0.7866	D(G(z)): 0.5147 / 0.0077
[20/100][0/93]	Loss_D: 1.4299	Loss_G: 2.2622	D(x): 0.4541	D(G(z)): 0.1455 / 0.1978
[20/100][50/93]	Loss_D: 0.9409	Loss_G: 2.4589	D(x): 0.6924	D(G(z)): 0.3783 / 0.1071
[21/100][0/93]	Loss_D: 1.8318	Loss_G: 3.8064	D(x): 0.8431	D(G(z)): 0.6862 / 0.0325
[21/100][50/93]	Loss_D: 1.2066	Loss_G: 2.1727	D(x): 0.5239	D(G(z)): 0.3132 / 0.1538
[22/100][0/93]	Loss_D: 0.5714	Loss_G: 4.0216	D(x): 0.8381	D(G(z)): 0.2414 / 0.0381
[22/100][50/93]	Loss_D: 0.8990	Loss_G: 2.8960	D(x): 0.5003	D(G(z)): 0.0758 / 0.1195
[23/100][0/93]	Loss_D: 2.2629	Loss_G: 3.9044	D(x): 0.5930	D(G(z)): 0.4464 / 0.0491
[23/100][50/93]	Loss_D: 1.4069	Loss_G: 3.4361	D(x): 0.8317	D(G(z)): 0.6443 / 0.0545
[24/100][0/93]	Loss_D: 2.5020	Loss_G: 1.8896	D(x): 0.2936	D(G(z)): 0.0756 / 0.2724
[24/100][50/93]	Loss_D: 1.9388	Loss_G: 4.3432	D(x): 0.7516	D(G(z)): 0.7009 / 0.0217
[25/100][0/93]	Loss_D: 1.5774	Loss_G: 2.5140	D(x): 0.3555	D(G(z)): 0.1111 / 0.1388
[25/100][50/93]	Loss_D: 0.8796	Loss_G: 5.1617	D(x): 0.8614	D(G(z)): 0.4593 / 0.0152
[26/100][0/93]	Loss_D: 1.2471	Loss_G: 2.3314	D(x): 0.7924	D(G(z)): 0.5335 / 0.1419
[26/100][50/93]	Loss_D: 1.0241	Loss_G: 2.6725	D(x): 0.5944	D(G(z)): 0.2025 / 0.1230
[27/100][0/93]	Loss_D: 2.7785	Loss_G: 2.4545	D(x): 0.9845	D(G(z)): 0.8878 / 0.1589
[27/100][50/93]	Loss_D: 1.7111	Loss_G: 3.3541	D(x): 0.6441	D(G(z)): 0.6389 / 0.0517
[28/100][0/93]	Loss_D: 2.1377	Loss_G: 4.4104	D(x): 0.9630	D(G(z)): 0.8030 / 0.0325
[28/100][50/93]	Loss_D: 0.8493	Loss_G: 2.6771	D(x): 0.6414	D(G(z)): 0.2139 / 0.0940
[29/100][0/93]	Loss_D: 0.5737	Loss_G: 2.3015	D(x): 0.7383	D(G(z)): 0.1876 / 0.1303
[29/100][50/93]	Loss_D: 0.5571	Loss_G: 3.0367	D(x): 0.7029	D(G(z)): 0.1287 / 0.0823
[30/100][0/93]	Loss_D: 1.0021	Loss_G: 1.4303	D(x): 0.5547	D(G(z)): 0.2125 / 0.3122
[30/100][50/93]	Loss_D: 1.6175	Loss_G: 1.2539	D(x): 0.3822	D(G(z)): 0.2951 / 0.3275
[31/100][0/93]	Loss_D: 2.7864	Loss_G: 2.2018	D(x): 0.1919	D(G(z)): 0.1527 / 0.2177
[31/100][50/93]	Loss_D: 1.0359	Loss_G: 2.5108	D(x): 0.6575	D(G(z)): 0.3703 / 0.1115
[32/100][0/93]	Loss_D: 1.9158	Loss_G: 2.6807	D(x): 0.9170	D(G(z)): 0.7546 / 0.1053
[32/100][50/93]	Loss_D: 1.9358	Loss_G: 1.8482	D(x): 0.4792	D(G(z)): 0.5489 / 0.2008
[33/100][0/93]	Loss_D: 2.3778	Loss_G: 2.5966	D(x): 0.3007	D(G(z)): 0.2455 / 0.1762
[33/100][50/93]	Loss_D: 0.7805	Loss_G: 1.6653	D(x): 0.6461	D(G(z)): 0.2187 / 0.2363
[34/100][0/93]	Loss_D: 1.1617	Loss_G: 2.4731	D(x): 0.5387	D(G(z)): 0.2090 / 0.1410
[34/100][50/93]	Loss_D: 1.0139	Loss_G: 1.8626	D(x): 0.4932	D(G(z)): 0.1352 / 0.2212
[35/100][0/93]	Loss_D: 1.2819	Loss_G: 2.4171	D(x): 0.6179	D(G(z)): 0.3117 / 0.2100
[35/100][50/93]	Loss_D: 1.0666	Loss_G: 2.3973	D(x): 0.6128	D(G(z)): 0.3099 / 0.1430
[36/100][0/93]	Loss_D: 1.5554	Loss_G: 0.6270	D(x): 0.3460	D(G(z)): 0.1852 / 0.5891
[36/100][50/93]	Loss_D: 0.7867	Loss_G: 3.6784	D(x): 0.7647	D(G(z)): 0.3462 / 0.0372
[37/100][0/93]	Loss_D: 1.3258	Loss_G: 2.3225	D(x): 0.3627	D(G(z)): 0.0459 / 0.1876
[37/100][50/93]	Loss_D: 0.9710	Loss_G: 2.4323	D(x): 0.6721	D(G(z)): 0.3169 / 0.1281
[38/100][0/93]	Loss_D: 0.9660	Loss_G: 3.1588	D(x): 0.6114	D(G(z)): 0.0931 / 0.1121
[38/100][50/93]	Loss_D: 1.2708	Loss_G: 2.6257	D(x): 0.7058	D(G(z)): 0.4844 / 0.1136
[39/100][0/93]	Loss_D: 2.3114	Loss_G: 3.5862	D(x): 0.8411	D(G(z)): 0.7433 / 0.0716
[39/100][50/93]	Loss_D: 1.0918	Loss_G: 2.1977	D(x): 0.6381	D(G(z)): 0.3201 / 0.1630
[40/100][0/93]	Loss_D: 1.6300	Loss_G: 2.1574	D(x): 0.5619	D(G(z)): 0.3563 / 0.2015
[40/100][50/93]	Loss_D: 1.2009	Loss_G: 1.7661	D(x): 0.4334	D(G(z)): 0.1192 / 0.2488
[41/100][0/93]	Loss_D: 1.9473	Loss_G: 5.1857	D(x): 0.9311	D(G(z)): 0.7065 / 0.0245
[41/100][50/93]	Loss_D: 0.8458	Loss_G: 3.7414	D(x): 0.7786	D(G(z)): 0.3648 / 0.0382
[42/100][0/93]	Loss_D: 2.3187	Loss_G: 4.2274	D(x): 0.9027	D(G(z)): 0.7909 / 0.0501
[42/100][50/93]	Loss_D: 0.5929	Loss_G: 2.6997	D(x): 0.7687	D(G(z)): 0.2301 / 0.1085
[43/100][0/93]	Loss_D: 0.8762	Loss_G: 2.9043	D(x): 0.5853	D(G(z)): 0.0705 / 0.1188
[43/100][50/93]	Loss_D: 1.0880	Loss_G: 3.1332	D(x): 0.8245	D(G(z)): 0.4706 / 0.0732
[44/100][0/93]	Loss_D: 0.6111	Loss_G: 3.6152	D(x): 0.7505	D(G(z)): 0.1722 / 0.0756
[44/100][50/93]	Loss_D: 1.2064	Loss_G: 3.7893	D(x): 0.8448	D(G(z)): 0.5339 / 0.0545
[45/100][0/93]	Loss_D: 1.3437	Loss_G: 4.3056	D(x): 0.8959	D(G(z)): 0.5562 / 0.0284
[45/100][50/93]	Loss_D: 0.6329	Loss_G: 2.6639	D(x): 0.7889	D(G(z)): 0.2604 / 0.1028
[46/100][0/93]	Loss_D: 0.7112	Loss_G: 2.7402	D(x): 0.7684	D(G(z)): 0.2652 / 0.0961
[46/100][50/93]	Loss_D: 1.0318	Loss_G: 1.6051	D(x): 0.5776	D(G(z)): 0.2185 / 0.2891
[47/100][0/93]	Loss_D: 2.1587	Loss_G: 0.6412	D(x): 0.2336	D(G(z)): 0.0759 / 0.6635
[47/100][50/93]	Loss_D: 1.4636	Loss_G: 0.9999	D(x): 0.3244	D(G(z)): 0.0739 / 0.4295
[48/100][0/93]	Loss_D: 2.2659	Loss_G: 3.4730	D(x): 0.9436	D(G(z)): 0.7613 / 0.1241
[48/100][50/93]	Loss_D: 0.8932	Loss_G: 3.8968	D(x): 0.9163	D(G(z)): 0.4613 / 0.0435
[49/100][0/93]	Loss_D: 1.6099	Loss_G: 3.5561	D(x): 0.7170	D(G(z)): 0.3702 / 0.1814
[49/100][50/93]	Loss_D: 1.1243	Loss_G: 2.0715	D(x): 0.5380	D(G(z)): 0.2477 / 0.1753
[50/100][0/93]	Loss_D: 1.4756	Loss_G: 3.7362	D(x): 0.9562	D(G(z)): 0.6335 / 0.0529
[50/100][50/93]	Loss_D: 0.8148	Loss_G: 3.3667	D(x): 0.8174	D(G(z)): 0.3246 / 0.0743
[51/100][0/93]	Loss_D: 1.0757	Loss_G: 2.2240	D(x): 0.6342	D(G(z)): 0.1822 / 0.2005
[51/100][50/93]	Loss_D: 0.6628	Loss_G: 3.3086	D(x): 0.8101	D(G(z)): 0.2859 / 0.0554
[52/100][0/93]	Loss_D: 1.7151	Loss_G: 5.3445	D(x): 0.8640	D(G(z)): 0.5757 / 0.0234
[52/100][50/93]	Loss_D: 0.6499	Loss_G: 3.1340	D(x): 0.8274	D(G(z)): 0.2812 / 0.0689
[53/100][0/93]	Loss_D: 3.9578	Loss_G: 5.0066	D(x): 0.9872	D(G(z)): 0.8885 / 0.0627
[53/100][50/93]	Loss_D: 0.6654	Loss_G: 2.3179	D(x): 0.6457	D(G(z)): 0.0873 / 0.1510
[54/100][0/93]	Loss_D: 1.2205	Loss_G: 6.3177	D(x): 0.9503	D(G(z)): 0.4954 / 0.0073
[54/100][50/93]	Loss_D: 0.4608	Loss_G: 3.7845	D(x): 0.8039	D(G(z)): 0.1609 / 0.0434
[55/100][0/93]	Loss_D: 0.6900	Loss_G: 2.2938	D(x): 0.6360	D(G(z)): 0.1005 / 0.1988
[55/100][50/93]	Loss_D: 0.3855	Loss_G: 4.1904	D(x): 0.8773	D(G(z)): 0.1713 / 0.0260
[56/100][0/93]	Loss_D: 1.3378	Loss_G: 5.4624	D(x): 0.8578	D(G(z)): 0.3810 / 0.0358
[56/100][50/93]	Loss_D: 0.3627	Loss_G: 3.2429	D(x): 0.8475	D(G(z)): 0.1269 / 0.0777
[57/100][0/93]	Loss_D: 1.3435	Loss_G: 6.4843	D(x): 0.9568	D(G(z)): 0.6105 / 0.0045
[57/100][50/93]	Loss_D: 0.3627	Loss_G: 3.7049	D(x): 0.8852	D(G(z)): 0.1759 / 0.0428
[58/100][0/93]	Loss_D: 0.8754	Loss_G: 2.6383	D(x): 0.6550	D(G(z)): 0.0860 / 0.2380
[58/100][50/93]	Loss_D: 0.7884	Loss_G: 4.4666	D(x): 0.8836	D(G(z)): 0.4134 / 0.0215
[59/100][0/93]	Loss_D: 1.0084	Loss_G: 7.4436	D(x): 0.9613	D(G(z)): 0.5033 / 0.0013
[59/100][50/93]	Loss_D: 0.6462	Loss_G: 3.3919	D(x): 0.6255	D(G(z)): 0.0354 / 0.0929
[60/100][0/93]	Loss_D: 2.4752	Loss_G: 6.4355	D(x): 0.9544	D(G(z)): 0.7417 / 0.0178
[60/100][50/93]	Loss_D: 0.6469	Loss_G: 3.2756	D(x): 0.7684	D(G(z)): 0.2437 / 0.0640
[61/100][0/93]	Loss_D: 1.3451	Loss_G: 6.0108	D(x): 0.9857	D(G(z)): 0.5472 / 0.0204
[61/100][50/93]	Loss_D: 0.2907	Loss_G: 3.8346	D(x): 0.8662	D(G(z)): 0.1117 / 0.0411
[62/100][0/93]	Loss_D: 3.7565	Loss_G: 8.8503	D(x): 0.9860	D(G(z)): 0.8801 / 0.0046
[62/100][50/93]	Loss_D: 0.3811	Loss_G: 3.5771	D(x): 0.8389	D(G(z)): 0.1373 / 0.0542
[63/100][0/93]	Loss_D: 2.0750	Loss_G: 8.8962	D(x): 0.9898	D(G(z)): 0.7575 / 0.0008
[63/100][50/93]	Loss_D: 0.3268	Loss_G: 4.7024	D(x): 0.8273	D(G(z)): 0.0499 / 0.0258
[64/100][0/93]	Loss_D: 0.8053	Loss_G: 6.4134	D(x): 0.9731	D(G(z)): 0.3950 / 0.0057
[64/100][50/93]	Loss_D: 0.3953	Loss_G: 5.7989	D(x): 0.9553	D(G(z)): 0.2434 / 0.0072
[65/100][0/93]	Loss_D: 2.7963	Loss_G: 7.5172	D(x): 0.9988	D(G(z)): 0.8158 / 0.0064
[65/100][50/93]	Loss_D: 0.3890	Loss_G: 4.2586	D(x): 0.8875	D(G(z)): 0.1830 / 0.0290
[66/100][0/93]	Loss_D: 2.0152	Loss_G: 8.0582	D(x): 0.9820	D(G(z)): 0.6841 / 0.0008
[66/100][50/93]	Loss_D: 0.4067	Loss_G: 3.9679	D(x): 0.7515	D(G(z)): 0.0254 / 0.0407
[67/100][0/93]	Loss_D: 0.9672	Loss_G: 5.2643	D(x): 0.6855	D(G(z)): 0.0499 / 0.0879
[67/100][50/93]	Loss_D: 0.5831	Loss_G: 3.8438	D(x): 0.7892	D(G(z)): 0.2042 / 0.0350
[68/100][0/93]	Loss_D: 4.1007	Loss_G: 9.8992	D(x): 0.9981	D(G(z)): 0.8449 / 0.0077
[68/100][50/93]	Loss_D: 0.1625	Loss_G: 4.5259	D(x): 0.9166	D(G(z)): 0.0610 / 0.0272
[69/100][0/93]	Loss_D: 0.1479	Loss_G: 4.5000	D(x): 0.9411	D(G(z)): 0.0715 / 0.0278
[69/100][50/93]	Loss_D: 0.1519	Loss_G: 6.2512	D(x): 0.9672	D(G(z)): 0.0954 / 0.0048
[70/100][0/93]	Loss_D: 0.8392	Loss_G: 5.8756	D(x): 0.9786	D(G(z)): 0.3902 / 0.0217
[70/100][50/93]	Loss_D: 0.3707	Loss_G: 4.4626	D(x): 0.7515	D(G(z)): 0.0334 / 0.0347
[71/100][0/93]	Loss_D: 0.8310	Loss_G: 8.1234	D(x): 0.9455	D(G(z)): 0.4131 / 0.0010
[71/100][50/93]	Loss_D: 0.4376	Loss_G: 2.1381	D(x): 0.7533	D(G(z)): 0.0536 / 0.2369
[72/100][0/93]	Loss_D: 1.0431	Loss_G: 3.2314	D(x): 0.5088	D(G(z)): 0.0038 / 0.2200
[72/100][50/93]	Loss_D: 0.5805	Loss_G: 7.3069	D(x): 0.9804	D(G(z)): 0.2963 / 0.0019
[73/100][0/93]	Loss_D: 2.8700	Loss_G: 9.9494	D(x): 0.9496	D(G(z)): 0.6259 / 0.0006
[73/100][50/93]	Loss_D: 0.3785	Loss_G: 3.5431	D(x): 0.8275	D(G(z)): 0.0818 / 0.0928
[74/100][0/93]	Loss_D: 1.6961	Loss_G: 7.5021	D(x): 0.9969	D(G(z)): 0.6196 / 0.0139
[74/100][50/93]	Loss_D: 0.2629	Loss_G: 3.7404	D(x): 0.8146	D(G(z)): 0.0166 / 0.0624
[75/100][0/93]	Loss_D: 0.1206	Loss_G: 4.8163	D(x): 0.9432	D(G(z)): 0.0482 / 0.0272
[75/100][50/93]	Loss_D: 0.3696	Loss_G: 5.4884	D(x): 0.9040	D(G(z)): 0.1696 / 0.0091
[76/100][0/93]	Loss_D: 0.8842	Loss_G: 10.1655	D(x): 0.9953	D(G(z)): 0.3930 / 0.0002
[76/100][50/93]	Loss_D: 0.2604	Loss_G: 3.9907	D(x): 0.8391	D(G(z)): 0.0429 / 0.0650
[77/100][0/93]	Loss_D: 1.5262	Loss_G: 10.1721	D(x): 0.9886	D(G(z)): 0.5355 / 0.0003
[77/100][50/93]	Loss_D: 0.3340	Loss_G: 4.1095	D(x): 0.8047	D(G(z)): 0.0317 / 0.0550
[78/100][0/93]	Loss_D: 3.4326	Loss_G: 10.2262	D(x): 0.9872	D(G(z)): 0.6616 / 0.0016
[78/100][50/93]	Loss_D: 0.1676	Loss_G: 5.5463	D(x): 0.9035	D(G(z)): 0.0413 / 0.0167
[79/100][0/93]	Loss_D: 0.1273	Loss_G: 5.5658	D(x): 0.9771	D(G(z)): 0.0895 / 0.0100
[79/100][50/93]	Loss_D: 0.1196	Loss_G: 6.1986	D(x): 0.9945	D(G(z)): 0.0941 / 0.0047
[80/100][0/93]	Loss_D: 2.3309	Loss_G: 13.2917	D(x): 0.9867	D(G(z)): 0.6010 / 0.0000
[80/100][50/93]	Loss_D: 0.4412	Loss_G: 3.9795	D(x): 0.7322	D(G(z)): 0.0339 / 0.1024
[81/100][0/93]	Loss_D: 4.3847	Loss_G: 11.8461	D(x): 0.9607	D(G(z)): 0.8149 / 0.0001
[81/100][50/93]	Loss_D: 0.7522	Loss_G: 3.2038	D(x): 0.6605	D(G(z)): 0.1330 / 0.1002
[82/100][0/93]	Loss_D: 2.1682	Loss_G: 7.2129	D(x): 0.9370	D(G(z)): 0.6094 / 0.0194
[82/100][50/93]	Loss_D: 0.3350	Loss_G: 4.8098	D(x): 0.8828	D(G(z)): 0.1488 / 0.0231
[83/100][0/93]	Loss_D: 1.7104	Loss_G: 6.4060	D(x): 0.9952	D(G(z)): 0.6291 / 0.0207
[83/100][50/93]	Loss_D: 0.2840	Loss_G: 5.2562	D(x): 0.8959	D(G(z)): 0.0825 / 0.0171
[84/100][0/93]	Loss_D: 0.1664	Loss_G: 5.1611	D(x): 0.9804	D(G(z)): 0.1151 / 0.0281
[84/100][50/93]	Loss_D: 0.3435	Loss_G: 5.3042	D(x): 0.8096	D(G(z)): 0.0295 / 0.0210
[85/100][0/93]	Loss_D: 0.5872	Loss_G: 7.6036	D(x): 0.9728	D(G(z)): 0.2683 / 0.0049
[85/100][50/93]	Loss_D: 0.0818	Loss_G: 5.2028	D(x): 0.9696	D(G(z)): 0.0446 / 0.0138
[86/100][0/93]	Loss_D: 3.9731	Loss_G: 11.9662	D(x): 0.9964	D(G(z)): 0.8531 / 0.0018
[86/100][50/93]	Loss_D: 0.0825	Loss_G: 5.6252	D(x): 0.9711	D(G(z)): 0.0454 / 0.0132
[87/100][0/93]	Loss_D: 0.0703	Loss_G: 5.5951	D(x): 0.9944	D(G(z)): 0.0511 / 0.0137
[87/100][50/93]	Loss_D: 0.2257	Loss_G: 6.7898	D(x): 0.9840	D(G(z)): 0.1499 / 0.0056
[88/100][0/93]	Loss_D: 4.5821	Loss_G: 12.6655	D(x): 0.9996	D(G(z)): 0.8642 / 0.0002
[88/100][50/93]	Loss_D: 0.4079	Loss_G: 5.5097	D(x): 0.9079	D(G(z)): 0.1865 / 0.0104
[89/100][0/93]	Loss_D: 0.8886	Loss_G: 5.3182	D(x): 0.8536	D(G(z)): 0.2253 / 0.1090
[89/100][50/93]	Loss_D: 0.3711	Loss_G: 5.0237	D(x): 0.7786	D(G(z)): 0.0333 / 0.0704
[90/100][0/93]	Loss_D: 0.7508	Loss_G: 9.3780	D(x): 0.9836	D(G(z)): 0.4002 / 0.0003
[90/100][50/93]	Loss_D: 0.1053	Loss_G: 5.6548	D(x): 0.9250	D(G(z)): 0.0147 / 0.0187
[91/100][0/93]	Loss_D: 1.4961	Loss_G: 11.8283	D(x): 0.9697	D(G(z)): 0.4693 / 0.0004
[91/100][50/93]	Loss_D: 0.1997	Loss_G: 6.9225	D(x): 0.9877	D(G(z)): 0.1128 / 0.0048
[92/100][0/93]	Loss_D: 2.2053	Loss_G: 12.7526	D(x): 0.9765	D(G(z)): 0.5430 / 0.0007
[92/100][50/93]	Loss_D: 0.1007	Loss_G: 6.0625	D(x): 0.9817	D(G(z)): 0.0683 / 0.0095
[93/100][0/93]	Loss_D: 0.0461	Loss_G: 6.3525	D(x): 0.9765	D(G(z)): 0.0179 / 0.0075
[93/100][50/93]	Loss_D: 0.0465	Loss_G: 6.9129	D(x): 0.9825	D(G(z)): 0.0253 / 0.0046
[94/100][0/93]	Loss_D: 0.5635	Loss_G: 10.4235	D(x): 0.9961	D(G(z)): 0.2810 / 0.0001
[94/100][50/93]	Loss_D: 0.1372	Loss_G: 6.8188	D(x): 0.9875	D(G(z)): 0.0846 / 0.0053
[95/100][0/93]	Loss_D: 0.4655	Loss_G: 11.1938	D(x): 0.9983	D(G(z)): 0.2945 / 0.0001
[95/100][50/93]	Loss_D: 0.9106	Loss_G: 2.9039	D(x): 0.5687	D(G(z)): 0.0029 / 0.3533
[96/100][0/93]	Loss_D: 0.5251	Loss_G: 5.7969	D(x): 0.7349	D(G(z)): 0.0027 / 0.0485
[96/100][50/93]	Loss_D: 0.5280	Loss_G: 8.6875	D(x): 0.9989	D(G(z)): 0.2794 / 0.0013
[97/100][0/93]	Loss_D: 1.6312	Loss_G: 8.7170	D(x): 0.6171	D(G(z)): 0.0312 / 0.3122
[97/100][50/93]	Loss_D: 0.1625	Loss_G: 5.9454	D(x): 0.9163	D(G(z)): 0.0399 / 0.0114
[98/100][0/93]	Loss_D: 5.3428	Loss_G: 12.0925	D(x): 0.9905	D(G(z)): 0.6376 / 0.0457
[98/100][50/93]	Loss_D: 0.1264	Loss_G: 5.8150	D(x): 0.9373	D(G(z)): 0.0381 / 0.0119
[99/100][0/93]	Loss_D: 0.4290	Loss_G: 9.4335	D(x): 0.9799	D(G(z)): 0.2182 / 0.0023
[99/100][50/93]	Loss_D: 0.0779	Loss_G: 7.8404	D(x): 0.9801	D(G(z)): 0.0307 / 0.0023
In [11]:
plt.figure(figsize=(10,5))
plt.title("Generator and Discriminator Loss During Training")
plt.plot(G_losses,label="G")
plt.plot(D_losses,label="D")
plt.xlabel("iterations")
plt.ylabel("Loss")
plt.legend()
plt.show()
In [12]:
#%%capture
fig = plt.figure(figsize=(8,8))
plt.axis("off")
ims = [[plt.imshow(np.transpose(i,(1,2,0)), animated=True)] for i in img_list]
ani = animation.ArtistAnimation(fig, ims, interval=1000, repeat_delay=1000, blit=True)

HTML(ani.to_jshtml())
Out[12]:
In [13]:
# Grab a batch of real images from the dataloader
real_batch = next(iter(dataloader))

# Plot the real images
plt.figure(figsize=(15,15))
plt.subplot(1,2,1)
plt.axis("off")
plt.title("Real Images")
plt.imshow(np.transpose(vutils.make_grid(real_batch[0].to(device)[:64], padding=5, normalize=True).cpu(),(1,2,0)))

# Plot the fake images from the last epoch
plt.subplot(1,2,2)
plt.axis("off")
plt.title("Fake Images")
plt.imshow(np.transpose(img_list[-1],(1,2,0)))
plt.show()
In [ ]: